package com.troila.xfspark.websocket;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.troila.xfspark.model.SparkProperties;
import com.troila.xfspark.model.SparkResponse;
import com.troila.xfspark.model.SparkRoleContent;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
import org.springframework.web.socket.handler.AbstractWebSocketHandler;
import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor;

import java.io.IOException;
import java.util.Collection;
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;

@Slf4j
@Configuration
public class SparkWebSocketConfig implements WebSocketConfigurer {

    private SparkProperties properties;

    private ObjectMapper objectMapper;

    @Autowired
    public void setProperties(SparkProperties properties) {
        this.properties = properties;
    }

    @Autowired
    public void setObjectMapper(ObjectMapper objectMapper) {
        this.objectMapper = objectMapper;
    }

    @Override
    public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
        try {
            registry.addHandler(buildMessageHandler(), SparkWrapper.getInstance().websocketUrl(properties))
                    .addInterceptors(new HttpSessionHandshakeInterceptor())
                    .setAllowedOrigins("*");
        } catch (Exception e) {
            logger.error(e.getMessage());
        }
    }

    @Bean(name = "sparkWebSocketHandler")
    public WebSocketHandler buildMessageHandler() {
        return new AbstractWebSocketHandler() {
            @Override
            public void afterConnectionEstablished(WebSocketSession session) throws Exception {
                super.afterConnectionEstablished(session);
                WebSocketSessionManager.setSparkSession(session);
                logger.info("[spark websocket] 连接成功1");
            }

            @Override
            public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
                super.afterConnectionClosed(session, status);
                WebSocketSessionManager.setSparkSession(null);
                logger.info("[spark websocket] 退出连接");
            }

            @Override
            protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
                super.handleTextMessage(session, message);
                logger.info("[spark websocket] 收到讯飞消息：{}", message.getPayload());
                try {
                    SparkResponse response = objectMapper.readValue(message.getPayload(), new TypeReference<SparkResponse>() {
                    });
                    if (response.getHeader().getCode() != 0) {
                        logger.error("发生错误，错误码为：{}", response.getHeader().getCode());
                        logger.error("本次请求的sid为：{}", response.getHeader().getSid());
                        session.close();
                    } else {
                        String sid = response.getHeader().getSid();
                        SparkResponse.SparkChoices sparkChoices = response.getPayload().getChoices();
                        WebSocketSessionManager.putChoices(sid, sparkChoices);

                        // 表示最后一个结果
                        if (response.getHeader().getStatus() == 2) {
                            List<SparkResponse.SparkChoices> choicesList = WebSocketSessionManager.getChoices(sid);
                            String content = choicesList.stream()
                                    .sorted(Comparator.comparing(SparkResponse.SparkChoices::getSeq))
                                    .map(SparkResponse.SparkChoices::getText)
                                    .flatMap(Collection::stream)
                                    .map(SparkRoleContent::getContent)
                                    .collect(Collectors.joining(""));
                            WebSocketSessionManager.removeChoices(sid);

                            SparkRoleContent roleContent = SparkRoleContent.builder().content(content).role("assistant").build();
                            WebSocketSessionManager.findAllSessions()
                                    .forEach(socketSession -> {
                                        try {
                                            String msg = objectMapper.writeValueAsString(roleContent);
                                            logger.info("[local websocket] 发送消息：{}", msg);
                                            socketSession.sendMessage(new TextMessage(msg));
                                        } catch (IOException e) {
                                            logger.error(e.getMessage());
                                        }
                                    });

                        }
                    }

                } catch (JsonProcessingException e) {
                    logger.error(e.getMessage());
                } catch (IOException e) {
                    logger.error(e.getMessage());
                }
            }
        };
    }


}
